import torch
import torch.nn as nn
class Net(nn.Module):
    """Minimal demo network: one 3x3 convolution followed by 2x2 max pooling.

    The convolution maps 1 input channel to 2 output channels; pooling then
    halves each spatial dimension.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 2 channels, 3x3 kernel (PyTorch defaults: stride 1, no padding).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3)
        # 2x2 window moved with stride 2.
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Run ``x`` through conv1 and then pool1, returning the pooled map."""
        return self.pool1(self.conv1(x))
def farward_hook(module, data_input, data_output):
    """Forward hook that records what a module received and produced.

    Intended for ``nn.Module.register_forward_hook``. It appends
    ``data_output`` to the module-level list ``fmap_block`` and ``data_input``
    to the module-level list ``input_block``. Both lists must exist as globals
    before the hooked module runs; otherwise calling this raises NameError.
    Returns None, so the module's output is left unchanged.
    """
    for store, item in ((fmap_block, data_output), (input_block, data_input)):
        store.append(item)
if __name__ == "__main__":
    # Build the network.
    net = Net()
    # Overwrite conv1's parameters with known values so the captured feature
    # maps are easy to check by hand. Editing a view of a leaf parameter that
    # requires grad in place is rejected by autograd, so do it under no_grad().
    with torch.no_grad():
        net.conv1.weight[0].fill_(1)
        net.conv1.weight[1].fill_(2)
        net.conv1.bias.zero_()

    # Register the hook. farward_hook appends into these two lists by global
    # name, so they must be bound at module scope before the forward pass.
    fmap_block = []
    input_block = []
    net.conv1.register_forward_hook(farward_hook)

    # Inference
    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
    output = net(fake_img)

    # Inspect the final output and what the hook captured from conv1.
    print("output shape: {}\noutput value: {}\n".format(output.shape, output))
    print("feature maps shape: {}\nfeature maps value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
    # The hook receives the positional inputs as a tuple; element 0 is the image.
    print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0][0]))